Problem Statement

A care management organisation called WeCare wants to identify among its diabetic patients, the ones that are at high risk of getting re-admitted to the hospital. They wish to intervene by providing some incentive to these patients that will help them improve their health. As the star analyst of this organisation, your job is to identify high-risk diabetic patients through risk stratification. This will help the payer to decide what are the right intervention programs for these patients.

Importing the libraries and setting configs

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style('whitegrid')

%matplotlib inline

pd.options.display.float_format = '{:.3f}'.format


import warnings
warnings.filterwarnings('ignore')

Utility functions

In [2]:
# Util fuction: line seperator
def print_ln():
    print('-' * 80, '\n')

Load the dataset and start Exploratory Data Analysis

In [3]:
# Loading in the dataset
diabetic_patient_data_orig = pd.read_csv('../resources/diabetic_data.csv')

# Create a working copy
diabetic_patient_data = diabetic_patient_data_orig.copy()

# Exploring the shape and info about the dataset
print('Dataframe Shape: ', diabetic_patient_data.shape)
print_ln()
print("Dataframe Info: \n")
diabetic_patient_data.info()
print_ln()
Dataframe Shape:  (101766, 50)
-------------------------------------------------------------------------------- 

Dataframe Info: 

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 101766 entries, 0 to 101765
Data columns (total 50 columns):
encounter_id                101766 non-null int64
patient_nbr                 101766 non-null int64
race                        101766 non-null object
gender                      101766 non-null object
age                         101766 non-null object
weight                      101766 non-null object
admission_type_id           101766 non-null int64
discharge_disposition_id    101766 non-null int64
admission_source_id         101766 non-null int64
time_in_hospital            101766 non-null int64
payer_code                  101766 non-null object
medical_specialty           101766 non-null object
num_lab_procedures          101766 non-null int64
num_procedures              101766 non-null int64
num_medications             101766 non-null int64
number_outpatient           101766 non-null int64
number_emergency            101766 non-null int64
number_inpatient            101766 non-null int64
diag_1                      101766 non-null object
diag_2                      101766 non-null object
diag_3                      101766 non-null object
number_diagnoses            101766 non-null int64
max_glu_serum               101766 non-null object
A1Cresult                   101766 non-null object
metformin                   101766 non-null object
repaglinide                 101766 non-null object
nateglinide                 101766 non-null object
chlorpropamide              101766 non-null object
glimepiride                 101766 non-null object
acetohexamide               101766 non-null object
glipizide                   101766 non-null object
glyburide                   101766 non-null object
tolbutamide                 101766 non-null object
pioglitazone                101766 non-null object
rosiglitazone               101766 non-null object
acarbose                    101766 non-null object
miglitol                    101766 non-null object
troglitazone                101766 non-null object
tolazamide                  101766 non-null object
examide                     101766 non-null object
citoglipton                 101766 non-null object
insulin                     101766 non-null object
glyburide-metformin         101766 non-null object
glipizide-metformin         101766 non-null object
glimepiride-pioglitazone    101766 non-null object
metformin-rosiglitazone     101766 non-null object
metformin-pioglitazone      101766 non-null object
change                      101766 non-null object
diabetesMed                 101766 non-null object
readmitted                  101766 non-null object
dtypes: int64(13), object(37)
memory usage: 38.8+ MB
-------------------------------------------------------------------------------- 

In [4]:
# Inspecting the head of the dataset
diabetic_patient_data.head(5)
Out[4]:
encounter_id patient_nbr race gender age weight admission_type_id discharge_disposition_id admission_source_id time_in_hospital ... citoglipton insulin glyburide-metformin glipizide-metformin glimepiride-pioglitazone metformin-rosiglitazone metformin-pioglitazone change diabetesMed readmitted
0 2278392 8222157 Caucasian Female [0-10) ? 6 25 1 1 ... No No No No No No No No No NO
1 149190 55629189 Caucasian Female [10-20) ? 1 1 7 3 ... No Up No No No No No Ch Yes >30
2 64410 86047875 AfricanAmerican Female [20-30) ? 1 1 7 2 ... No No No No No No No No Yes NO
3 500364 82442376 Caucasian Male [30-40) ? 1 1 7 2 ... No Up No No No No No Ch Yes NO
4 16680 42519267 Caucasian Male [40-50) ? 1 1 7 1 ... No Steady No No No No No Ch Yes NO

5 rows × 50 columns

Remove the redundant variables

In [4]:
# NOTE replace the `?` as `nan`

diabetic_patient_data = diabetic_patient_data.replace('?', np.nan)
diabetic_patient_data['medical_specialty'].replace({np.nan: 'Unknown'}, inplace=True)

# `encounter_id` is redundant for our purpose, so we can drop that as well
diabetic_patient_data = diabetic_patient_data.drop(['encounter_id'], axis=1)

# `patient_nbr` is redundant for our purpose, so we can drop that as well
diabetic_patient_data = diabetic_patient_data.drop(['patient_nbr'], axis=1)

# dropping the `diag` codes as well since they don't add direct value to the analysis
diabetic_patient_data = diabetic_patient_data.drop(['diag_1', 'diag_2', 'diag_3'], axis=1)

Checking and visualizing the missing values

In [5]:
# Analyse the missing values
columns_with_missing_data = round(100 * (diabetic_patient_data.isnull().sum() / len(diabetic_patient_data.index)), 2)
columns_with_missing_data[columns_with_missing_data > 30].plot(kind='bar')
plt.show()

Two columns have considerable data missing

  • weight
  • payer_code
In [6]:
# We can see that Weight column is almost completely empty and therefore can be dropped
diabetic_patient_data = diabetic_patient_data.drop(['weight'], axis=1)

# `payer_code` is redundant for our purpose, so we can drop that as well
diabetic_patient_data = diabetic_patient_data.drop(['payer_code'], axis=1)

Change the variable 'readmitted' to binary type by clubbing the values ">30" and "<30" as "YES".

In [7]:
diabetic_patient_data['readmitted'] = diabetic_patient_data['readmitted'].replace('>30', 'YES')
diabetic_patient_data['readmitted'] = diabetic_patient_data['readmitted'].replace('<30', 'YES')

Remove the duplicated rows and columns

In [8]:
diabetic_patient_data = diabetic_patient_data.drop_duplicates()

Checking whether data is imbalanced or not wrt target variable

In [9]:
readmitted_df = diabetic_patient_data["readmitted"].value_counts()

readmitted_df
Out[9]:
NO     54861
YES    46902
Name: readmitted, dtype: int64
In [10]:
diabetic_patient_data_rate = readmitted_df[1] / (readmitted_df[1] + readmitted_df[0])
diabetic_patient_data_rate
Out[10]:
0.4608944311783261
In [11]:
print("Total readmission Count     = {}".format(readmitted_df[1]))
print("Total Non-readmission Count = {}".format(readmitted_df[0]))
print("Readmission Rate            = {:.2f}%".format(diabetic_patient_data_rate*100))
print_ln()
Total readmission Count     = 46902
Total Non-readmission Count = 54861
Readmission Rate            = 46.09%
-------------------------------------------------------------------------------- 

We can see from above that the data is balanced and we can continue with the analysis.

Identifying the numerical and categorical features

In [12]:
def type_features(data):
    categorical_features = data.select_dtypes(include=["object"]).columns
    numerical_features = data.select_dtypes(exclude=["object"]).columns
    print("categorical_features :", categorical_features)
    print_ln()
    print("numerical_features:", numerical_features)
    print_ln()
    return categorical_features, numerical_features


diabetic_patient_data_cat_features, diabetic_patient_data_num_features = type_features(diabetic_patient_data)
categorical_features : Index(['race', 'gender', 'age', 'medical_specialty', 'max_glu_serum',
       'A1Cresult', 'metformin', 'repaglinide', 'nateglinide',
       'chlorpropamide', 'glimepiride', 'acetohexamide', 'glipizide',
       'glyburide', 'tolbutamide', 'pioglitazone', 'rosiglitazone', 'acarbose',
       'miglitol', 'troglitazone', 'tolazamide', 'examide', 'citoglipton',
       'insulin', 'glyburide-metformin', 'glipizide-metformin',
       'glimepiride-pioglitazone', 'metformin-rosiglitazone',
       'metformin-pioglitazone', 'change', 'diabetesMed', 'readmitted'],
      dtype='object')
-------------------------------------------------------------------------------- 

numerical_features: Index(['admission_type_id', 'discharge_disposition_id', 'admission_source_id',
       'time_in_hospital', 'num_lab_procedures', 'num_procedures',
       'num_medications', 'number_outpatient', 'number_emergency',
       'number_inpatient', 'number_diagnoses'],
      dtype='object')
-------------------------------------------------------------------------------- 

Perform basic data exploration for some numerical attributes

In [13]:
diabetic_patient_data_num_features = [
    'time_in_hospital',
    'num_lab_procedures',
    'num_procedures',
    'num_medications',
    'number_outpatient',
    'number_emergency',
    'number_inpatient',
    'number_diagnoses']

diabetic_patient_data_num_features_df = diabetic_patient_data[diabetic_patient_data_num_features]

diabetic_patient_data_num_features_df.describe()
Out[13]:
time_in_hospital num_lab_procedures num_procedures num_medications number_outpatient number_emergency number_inpatient number_diagnoses
count 101763.000 101763.000 101763.000 101763.000 101763.000 101763.000 101763.000 101763.000
mean 4.396 43.096 1.340 16.022 0.369 0.198 0.636 7.423
std 2.985 19.675 1.706 8.128 1.267 0.930 1.263 1.934
min 1.000 1.000 0.000 1.000 0.000 0.000 0.000 1.000
25% 2.000 31.000 0.000 10.000 0.000 0.000 0.000 6.000
50% 4.000 44.000 1.000 15.000 0.000 0.000 0.000 8.000
75% 6.000 57.000 2.000 20.000 0.000 0.000 1.000 9.000
max 14.000 132.000 6.000 81.000 42.000 76.000 21.000 16.000
In [15]:
diabetic_patient_data_num_features_df.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 101763 entries, 0 to 101765
Data columns (total 8 columns):
time_in_hospital      101763 non-null int64
num_lab_procedures    101763 non-null int64
num_procedures        101763 non-null int64
num_medications       101763 non-null int64
number_outpatient     101763 non-null int64
number_emergency      101763 non-null int64
number_inpatient      101763 non-null int64
number_diagnoses      101763 non-null int64
dtypes: int64(8)
memory usage: 7.0 MB

Univariate analysis of some numerical attributes

In [14]:
for a_num_feature in diabetic_patient_data_num_features:
    sns.FacetGrid(diabetic_patient_data, hue="readmitted", height=6).map(sns.distplot, a_num_feature).add_legend()
    plt.show()

Bivariate analysis of some numerical attributes

In [15]:
diabetic_patient_data_num_features_df = diabetic_patient_data[diabetic_patient_data_num_features]
diabetic_patient_data_num_features_df.head()
Out[15]:
time_in_hospital num_lab_procedures num_procedures num_medications number_outpatient number_emergency number_inpatient number_diagnoses
0 1 41 0 1 0 0 0 1
1 3 59 0 18 0 0 0 9
2 2 11 5 13 2 0 1 6
3 2 44 1 16 0 0 0 7
4 1 51 0 8 0 0 0 5
In [16]:
# Create correlation matrix
corr_matrix = diabetic_patient_data_num_features_df.corr().abs()


# plotting correlations on a heatmap

# figure size
plt.figure(figsize=(18,10))

# heatmap
sns.heatmap(corr_matrix, cmap="YlGnBu", annot=True)
plt.show()

Isolating highly correlated ( > 80 ) values from the dataset

In [17]:
# Select upper triangle of correlation matrix
upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(np.bool))

# Find index of feature columns with correlation greater than 0.80
high_corr_features = [column for column in upper.columns if any(upper[column] > 0.60)]

print("HIGHLY CORRELATED FEATURES IN DATA SET:{}\n\n{}".format(len(high_corr_features), high_corr_features))
HIGHLY CORRELATED FEATURES IN DATA SET:0

[]

We can see from above that there aren't any variables which are higly correlated

Perform basic data exploration for some categorical attributes

In [18]:
diabetic_patient_data_cat_features = [
                                      'admission_type_id',  # NOTE cat-encoded numerical values
#                                       'discharge_disposition_id',  # NOTE cat-encoded numerical values
#                                       'admission_source_id',  # NOTE cat-encoded numerical values
                                      'race',
                                      'gender',
                                      'age',  
                                      'medical_specialty',  
                                      'max_glu_serum',  # NOTE has low variance
                                      'A1Cresult',
                                      # diabetes-med-start 
                                      'metformin',
                                      'repaglinide',
                                      'nateglinide',
                                      'chlorpropamide',
                                      'glimepiride',
                                      'acetohexamide',
                                      'glipizide',
                                      'glyburide',
                                      'tolbutamide',
                                      'pioglitazone',
                                      'rosiglitazone',
                                      'acarbose',
                                      'miglitol',
                                      'troglitazone',
                                      'tolazamide',
                                      'examide',
                                      'citoglipton',
                                      'insulin',
                                      'glyburide-metformin',
                                      'glipizide-metformin',
                                      'glimepiride-pioglitazone',
                                      'metformin-rosiglitazone',
                                      'metformin-pioglitazone',
                                      # diabetes-med-end
                                      'change',
                                      'diabetesMed',
                                      'readmitted'
                                      ]

for a_cat_feat in diabetic_patient_data_cat_features:
    print(diabetic_patient_data[a_cat_feat].value_counts().count())
    print(diabetic_patient_data[a_cat_feat].value_counts())
    print_ln()
8
1    53988
3    18868
2    18480
6     5291
5     4785
8      320
7       21
4       10
Name: admission_type_id, dtype: int64
-------------------------------------------------------------------------------- 

5
Caucasian          76097
AfricanAmerican    19209
Hispanic            2037
Other               1506
Asian                641
Name: race, dtype: int64
-------------------------------------------------------------------------------- 

3
Female             54705
Male               47055
Unknown/Invalid        3
Name: gender, dtype: int64
-------------------------------------------------------------------------------- 

10
[70-80)     26068
[60-70)     22483
[50-60)     17255
[80-90)     17197
[40-50)      9685
[30-40)      3775
[90-100)     2793
[20-30)      1656
[10-20)       690
[0-10)        161
Name: age, dtype: int64
-------------------------------------------------------------------------------- 

73
Unknown                             49948
InternalMedicine                    14635
Emergency/Trauma                     7565
Family/GeneralPractice               7440
Cardiology                           5352
                                    ...  
Proctology                              1
Pediatrics-InfectiousDiseases           1
Neurophysiology                         1
Perinatology                            1
Surgery-PlasticwithinHeadandNeck        1
Name: medical_specialty, Length: 73, dtype: int64
-------------------------------------------------------------------------------- 

4
None    96417
Norm     2597
>200     1485
>300     1264
Name: max_glu_serum, dtype: int64
-------------------------------------------------------------------------------- 

4
None    84746
>8       8215
Norm     4990
>7       3812
Name: A1Cresult, dtype: int64
-------------------------------------------------------------------------------- 

4
No        81775
Steady    18346
Up         1067
Down        575
Name: metformin, dtype: int64
-------------------------------------------------------------------------------- 

4
No        100224
Steady      1384
Up           110
Down          45
Name: repaglinide, dtype: int64
-------------------------------------------------------------------------------- 

4
No        101060
Steady       668
Up            24
Down          11
Name: nateglinide, dtype: int64
-------------------------------------------------------------------------------- 

4
No        101677
Steady        79
Up             6
Down           1
Name: chlorpropamide, dtype: int64
-------------------------------------------------------------------------------- 

4
No        96572
Steady     4670
Up          327
Down        194
Name: glimepiride, dtype: int64
-------------------------------------------------------------------------------- 

2
No        101762
Steady         1
Name: acetohexamide, dtype: int64
-------------------------------------------------------------------------------- 

4
No        89077
Steady    11356
Up          770
Down        560
Name: glipizide, dtype: int64
-------------------------------------------------------------------------------- 

4
No        91113
Steady     9274
Up          812
Down        564
Name: glyburide, dtype: int64
-------------------------------------------------------------------------------- 

2
No        101740
Steady        23
Name: tolbutamide, dtype: int64
-------------------------------------------------------------------------------- 

4
No        94435
Steady     6976
Up          234
Down        118
Name: pioglitazone, dtype: int64
-------------------------------------------------------------------------------- 

4
No        95398
Steady     6100
Up          178
Down         87
Name: rosiglitazone, dtype: int64
-------------------------------------------------------------------------------- 

4
No        101455
Steady       295
Up            10
Down           3
Name: acarbose, dtype: int64
-------------------------------------------------------------------------------- 

4
No        101725
Steady        31
Down           5
Up             2
Name: miglitol, dtype: int64
-------------------------------------------------------------------------------- 

2
No        101760
Steady         3
Name: troglitazone, dtype: int64
-------------------------------------------------------------------------------- 

3
No        101724
Steady        38
Up             1
Name: tolazamide, dtype: int64
-------------------------------------------------------------------------------- 

1
No    101763
Name: examide, dtype: int64
-------------------------------------------------------------------------------- 

1
No    101763
Name: citoglipton, dtype: int64
-------------------------------------------------------------------------------- 

4
No        47381
Steady    30848
Down      12218
Up        11316
Name: insulin, dtype: int64
-------------------------------------------------------------------------------- 

4
No        101057
Steady       692
Up             8
Down           6
Name: glyburide-metformin, dtype: int64
-------------------------------------------------------------------------------- 

2
No        101750
Steady        13
Name: glipizide-metformin, dtype: int64
-------------------------------------------------------------------------------- 

2
No        101762
Steady         1
Name: glimepiride-pioglitazone, dtype: int64
-------------------------------------------------------------------------------- 

2
No        101761
Steady         2
Name: metformin-rosiglitazone, dtype: int64
-------------------------------------------------------------------------------- 

2
No        101762
Steady         1
Name: metformin-pioglitazone, dtype: int64
-------------------------------------------------------------------------------- 

2
No    54752
Ch    47011
Name: change, dtype: int64
-------------------------------------------------------------------------------- 

2
Yes    78362
No     23401
Name: diabetesMed, dtype: int64
-------------------------------------------------------------------------------- 

2
NO     54861
YES    46902
Name: readmitted, dtype: int64
-------------------------------------------------------------------------------- 

Data Preparation

Scaling the numerical features

In [19]:
# scaling the features
from sklearn.preprocessing import scale

for a_num_feat in diabetic_patient_data_num_features:
    diabetic_patient_data[a_num_feat] = pd.DataFrame(scale(diabetic_patient_data[a_num_feat]))
    
diabetic_patient_data.head()
    
Out[19]:
race gender age admission_type_id discharge_disposition_id admission_source_id time_in_hospital medical_specialty num_lab_procedures num_procedures ... citoglipton insulin glyburide-metformin glipizide-metformin glimepiride-pioglitazone metformin-rosiglitazone metformin-pioglitazone change diabetesMed readmitted
0 Caucasian Female [0-10) 6 25 1 -1.138 Pediatrics-Endocrinology -0.107 -0.785 ... No No No No No No No No No NO
1 Caucasian Female [10-20) 1 1 7 -0.468 Unknown 0.808 -0.785 ... No Up No No No No No Ch Yes YES
2 AfricanAmerican Female [20-30) 1 1 7 -0.803 Unknown -1.631 2.146 ... No No No No No No No No Yes NO
3 Caucasian Male [30-40) 1 1 7 -0.803 Unknown 0.046 -0.199 ... No Up No No No No No Ch Yes NO
4 Caucasian Male [40-50) 1 1 7 -1.138 Unknown 0.402 -0.785 ... No Steady No No No No No Ch Yes NO

5 rows × 43 columns

Create dummy variables for categorical ones

In [20]:
for a_cat_feat in diabetic_patient_data_cat_features:
    tmp = pd.get_dummies(diabetic_patient_data[a_cat_feat], prefix=a_cat_feat, drop_first=True)
    diabetic_patient_data = pd.concat([diabetic_patient_data, tmp], axis=1)
    diabetic_patient_data = diabetic_patient_data.drop([a_cat_feat], 1)


diabetic_patient_data.head()
Out[20]:
discharge_disposition_id admission_source_id time_in_hospital num_lab_procedures num_procedures num_medications number_outpatient number_emergency number_inpatient number_diagnoses ... glyburide-metformin_No glyburide-metformin_Steady glyburide-metformin_Up glipizide-metformin_Steady glimepiride-pioglitazone_Steady metformin-rosiglitazone_Steady metformin-pioglitazone_Steady change_No diabetesMed_Yes readmitted_YES
0 25 1 -1.138 -0.107 -0.785 -1.848 -0.291 -0.213 -0.503 -3.322 ... 1 0 0 0 0 0 0 1 0 0
1 1 7 -0.468 0.808 -0.785 0.243 -0.291 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 0 1 1
2 1 7 -0.803 -1.631 2.146 -0.372 1.287 -0.213 0.289 -0.736 ... 1 0 0 0 0 0 0 1 1 0
3 1 7 -0.803 0.046 -0.199 -0.003 -0.291 -0.213 -0.503 -0.219 ... 1 0 0 0 0 0 0 0 1 0
4 1 7 -1.138 0.402 -0.785 -0.987 -0.291 -0.213 -0.503 -1.253 ... 1 0 0 0 0 0 0 0 1 0

5 rows × 161 columns

Engineered features in the dataset after the feature engineering stage

In [21]:
diabetic_patient_data.columns.to_list()
Out[21]:
['discharge_disposition_id',
 'admission_source_id',
 'time_in_hospital',
 'num_lab_procedures',
 'num_procedures',
 'num_medications',
 'number_outpatient',
 'number_emergency',
 'number_inpatient',
 'number_diagnoses',
 'admission_type_id_2',
 'admission_type_id_3',
 'admission_type_id_4',
 'admission_type_id_5',
 'admission_type_id_6',
 'admission_type_id_7',
 'admission_type_id_8',
 'race_Asian',
 'race_Caucasian',
 'race_Hispanic',
 'race_Other',
 'gender_Male',
 'gender_Unknown/Invalid',
 'age_[10-20)',
 'age_[20-30)',
 'age_[30-40)',
 'age_[40-50)',
 'age_[50-60)',
 'age_[60-70)',
 'age_[70-80)',
 'age_[80-90)',
 'age_[90-100)',
 'medical_specialty_Anesthesiology',
 'medical_specialty_Anesthesiology-Pediatric',
 'medical_specialty_Cardiology',
 'medical_specialty_Cardiology-Pediatric',
 'medical_specialty_DCPTEAM',
 'medical_specialty_Dentistry',
 'medical_specialty_Dermatology',
 'medical_specialty_Emergency/Trauma',
 'medical_specialty_Endocrinology',
 'medical_specialty_Endocrinology-Metabolism',
 'medical_specialty_Family/GeneralPractice',
 'medical_specialty_Gastroenterology',
 'medical_specialty_Gynecology',
 'medical_specialty_Hematology',
 'medical_specialty_Hematology/Oncology',
 'medical_specialty_Hospitalist',
 'medical_specialty_InfectiousDiseases',
 'medical_specialty_InternalMedicine',
 'medical_specialty_Nephrology',
 'medical_specialty_Neurology',
 'medical_specialty_Neurophysiology',
 'medical_specialty_Obsterics&Gynecology-GynecologicOnco',
 'medical_specialty_Obstetrics',
 'medical_specialty_ObstetricsandGynecology',
 'medical_specialty_Oncology',
 'medical_specialty_Ophthalmology',
 'medical_specialty_Orthopedics',
 'medical_specialty_Orthopedics-Reconstructive',
 'medical_specialty_Osteopath',
 'medical_specialty_Otolaryngology',
 'medical_specialty_OutreachServices',
 'medical_specialty_Pathology',
 'medical_specialty_Pediatrics',
 'medical_specialty_Pediatrics-AllergyandImmunology',
 'medical_specialty_Pediatrics-CriticalCare',
 'medical_specialty_Pediatrics-EmergencyMedicine',
 'medical_specialty_Pediatrics-Endocrinology',
 'medical_specialty_Pediatrics-Hematology-Oncology',
 'medical_specialty_Pediatrics-InfectiousDiseases',
 'medical_specialty_Pediatrics-Neurology',
 'medical_specialty_Pediatrics-Pulmonology',
 'medical_specialty_Perinatology',
 'medical_specialty_PhysicalMedicineandRehabilitation',
 'medical_specialty_PhysicianNotFound',
 'medical_specialty_Podiatry',
 'medical_specialty_Proctology',
 'medical_specialty_Psychiatry',
 'medical_specialty_Psychiatry-Addictive',
 'medical_specialty_Psychiatry-Child/Adolescent',
 'medical_specialty_Psychology',
 'medical_specialty_Pulmonology',
 'medical_specialty_Radiologist',
 'medical_specialty_Radiology',
 'medical_specialty_Resident',
 'medical_specialty_Rheumatology',
 'medical_specialty_Speech',
 'medical_specialty_SportsMedicine',
 'medical_specialty_Surgeon',
 'medical_specialty_Surgery-Cardiovascular',
 'medical_specialty_Surgery-Cardiovascular/Thoracic',
 'medical_specialty_Surgery-Colon&Rectal',
 'medical_specialty_Surgery-General',
 'medical_specialty_Surgery-Maxillofacial',
 'medical_specialty_Surgery-Neuro',
 'medical_specialty_Surgery-Pediatric',
 'medical_specialty_Surgery-Plastic',
 'medical_specialty_Surgery-PlasticwithinHeadandNeck',
 'medical_specialty_Surgery-Thoracic',
 'medical_specialty_Surgery-Vascular',
 'medical_specialty_SurgicalSpecialty',
 'medical_specialty_Unknown',
 'medical_specialty_Urology',
 'max_glu_serum_>300',
 'max_glu_serum_None',
 'max_glu_serum_Norm',
 'A1Cresult_>8',
 'A1Cresult_None',
 'A1Cresult_Norm',
 'metformin_No',
 'metformin_Steady',
 'metformin_Up',
 'repaglinide_No',
 'repaglinide_Steady',
 'repaglinide_Up',
 'nateglinide_No',
 'nateglinide_Steady',
 'nateglinide_Up',
 'chlorpropamide_No',
 'chlorpropamide_Steady',
 'chlorpropamide_Up',
 'glimepiride_No',
 'glimepiride_Steady',
 'glimepiride_Up',
 'acetohexamide_Steady',
 'glipizide_No',
 'glipizide_Steady',
 'glipizide_Up',
 'glyburide_No',
 'glyburide_Steady',
 'glyburide_Up',
 'tolbutamide_Steady',
 'pioglitazone_No',
 'pioglitazone_Steady',
 'pioglitazone_Up',
 'rosiglitazone_No',
 'rosiglitazone_Steady',
 'rosiglitazone_Up',
 'acarbose_No',
 'acarbose_Steady',
 'acarbose_Up',
 'miglitol_No',
 'miglitol_Steady',
 'miglitol_Up',
 'troglitazone_Steady',
 'tolazamide_Steady',
 'tolazamide_Up',
 'insulin_No',
 'insulin_Steady',
 'insulin_Up',
 'glyburide-metformin_No',
 'glyburide-metformin_Steady',
 'glyburide-metformin_Up',
 'glipizide-metformin_Steady',
 'glimepiride-pioglitazone_Steady',
 'metformin-rosiglitazone_Steady',
 'metformin-pioglitazone_Steady',
 'change_No',
 'diabetesMed_Yes',
 'readmitted_YES']

Final clean-up for the dataset before the model building

In [22]:
def clean_dataset(df):
    assert isinstance(df, pd.DataFrame), "df needs to be a pd.DataFrame"
    df.dropna(inplace=True)
    indices_to_keep = ~df.isin([np.nan, np.inf, -np.inf]).any(1)
    return df[indices_to_keep].astype(np.float64)


clean_dataset(diabetic_patient_data)
Out[22]:
discharge_disposition_id admission_source_id time_in_hospital num_lab_procedures num_procedures num_medications number_outpatient number_emergency number_inpatient number_diagnoses ... glyburide-metformin_No glyburide-metformin_Steady glyburide-metformin_Up glipizide-metformin_Steady glimepiride-pioglitazone_Steady metformin-rosiglitazone_Steady metformin-pioglitazone_Steady change_No diabetesMed_Yes readmitted_YES
0 25.000 1.000 -1.138 -0.107 -0.785 -1.848 -0.291 -0.213 -0.503 -3.322 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000 0.000
1 1.000 7.000 -0.468 0.808 -0.785 0.243 -0.291 -0.213 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000
2 1.000 7.000 -0.803 -1.631 2.146 -0.372 1.287 -0.213 0.289 -0.736 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 0.000
3 1.000 7.000 -0.803 0.046 -0.199 -0.003 -0.291 -0.213 -0.503 -0.219 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
4 1.000 7.000 -1.138 0.402 -0.785 -0.987 -0.291 -0.213 -0.503 -1.253 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
5 1.000 2.000 -0.468 -0.615 2.732 -0.003 -0.291 -0.213 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 1.000
6 1.000 2.000 -0.133 1.367 -0.199 0.612 -0.291 -0.213 -0.503 -0.219 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
7 1.000 7.000 0.202 1.520 -0.785 -0.495 -0.291 -0.213 -0.503 0.299 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 1.000
8 1.000 4.000 2.882 1.266 0.387 1.474 -0.291 -0.213 -0.503 0.299 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
9 3.000 4.000 2.547 -0.513 0.973 0.243 -0.291 -0.213 -0.503 0.299 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
10 1.000 7.000 1.542 0.198 0.387 0.120 -0.291 -0.213 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 1.000
11 1.000 4.000 0.872 0.961 -0.785 -0.618 -0.291 -0.213 -0.503 -0.219 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000
12 3.000 7.000 0.872 0.859 -0.785 -0.126 -0.291 0.862 -0.503 0.299 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000
13 6.000 7.000 1.877 0.605 -0.199 1.843 -0.291 -0.213 -0.503 0.299 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 0.000
14 1.000 2.000 -1.138 0.300 2.146 -1.725 -0.291 -0.213 -0.503 0.299 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 1.000
15 3.000 7.000 2.547 1.622 2.146 -0.372 -0.291 -0.213 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
16 1.000 7.000 -0.133 0.097 1.560 0.120 -0.291 -0.213 -0.503 0.299 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000
17 1.000 7.000 -0.468 -0.716 -0.785 -0.618 -0.291 -0.213 -0.503 -2.287 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 0.000
18 1.000 7.000 0.202 -0.411 2.146 0.859 -0.291 -0.213 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 1.000
19 6.000 2.000 0.537 -0.056 0.387 0.859 -0.291 -0.213 -0.503 0.299 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
20 1.000 4.000 -0.803 1.164 -0.199 0.366 -0.291 -0.213 -0.503 -0.219 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
21 1.000 4.000 -0.803 -0.361 0.387 -0.618 -0.291 -0.213 -0.503 -0.736 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
22 1.000 4.000 -0.803 0.198 -0.785 -0.495 -0.291 -0.213 -0.503 0.299 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000 0.000
23 6.000 1.000 2.212 -0.056 0.387 0.366 -0.291 -0.213 -0.503 0.299 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000 1.000
24 1.000 2.000 -0.468 -1.225 1.560 0.243 -0.291 -0.213 -0.503 -0.736 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
25 1.000 7.000 -1.138 -0.513 -0.785 -1.110 -0.291 -0.213 -0.503 -2.287 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 0.000
26 3.000 7.000 0.537 1.063 0.973 0.243 -0.291 -0.213 -0.503 -0.219 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
27 1.000 1.000 -0.803 -0.920 0.387 -0.618 -0.291 -0.213 -0.503 -2.287 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 1.000
28 1.000 2.000 1.877 0.503 -0.785 0.489 -0.291 -0.213 -0.503 -0.736 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000
29 2.000 7.000 0.202 0.453 -0.785 -0.249 -0.291 -0.213 -0.503 0.299 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
101733 1.000 7.000 1.207 0.554 -0.199 0.489 0.498 -0.213 0.289 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
101734 1.000 7.000 -0.133 0.351 1.560 1.105 -0.291 -0.213 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 0.000
101735 6.000 7.000 -1.138 -0.310 -0.785 -0.741 -0.291 -0.213 1.080 -0.219 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
101736 3.000 7.000 -1.138 -0.564 2.732 -0.249 1.287 -0.213 1.080 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 1.000
101737 1.000 1.000 -1.138 -2.140 2.146 -0.987 -0.291 -0.213 -0.503 -1.770 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
101738 1.000 7.000 -0.133 -1.123 -0.199 -0.987 -0.291 -0.213 0.289 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 0.000
101739 1.000 7.000 -0.133 -0.005 -0.785 -0.741 -0.291 -0.213 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 0.000
101740 1.000 1.000 -0.468 0.097 -0.785 1.228 0.498 -0.213 0.289 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
101741 3.000 7.000 -1.138 0.300 -0.785 -0.495 -0.291 -0.213 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000 0.000
101742 6.000 2.000 -0.468 0.707 -0.785 -1.110 -0.291 0.862 -0.503 -2.287 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
101743 13.000 7.000 -0.133 -2.089 -0.785 -1.110 0.498 -0.213 -0.503 -1.253 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
101744 1.000 7.000 1.207 0.402 2.732 0.366 -0.291 -0.213 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000 0.000
101745 1.000 1.000 3.217 1.317 -0.785 -0.003 -0.291 -0.213 -0.503 -1.253 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
101746 1.000 7.000 -0.468 -0.818 -0.199 1.597 -0.291 0.862 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 1.000
101747 22.000 7.000 -0.468 -0.615 0.387 0.982 -0.291 -0.213 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 1.000
101748 4.000 7.000 2.882 1.723 2.732 6.026 -0.291 -0.213 -0.503 4.436 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000
101749 6.000 1.000 -0.468 -1.530 -0.199 -1.356 -0.291 -0.213 -0.503 0.299 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
101750 1.000 1.000 2.882 0.402 0.387 -0.372 -0.291 -0.213 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000
101751 6.000 1.000 1.542 0.351 0.387 2.089 -0.291 -0.213 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
101752 1.000 1.000 3.217 1.520 2.732 1.228 -0.291 0.862 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
101753 1.000 7.000 -0.803 0.148 2.732 0.120 0.498 0.862 0.289 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
101754 1.000 7.000 0.202 -1.123 -0.199 -0.003 -0.291 -0.213 0.289 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000
101755 1.000 7.000 0.202 1.672 -0.199 0.736 -0.291 0.862 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000
101756 1.000 7.000 -1.138 -2.140 -0.785 -0.126 2.076 -0.213 -0.503 -0.219 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 1.000
101757 1.000 7.000 0.537 0.097 -0.199 1.105 2.076 0.862 1.080 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 0.000
101758 1.000 7.000 -0.468 0.402 -0.785 -0.003 -0.291 -0.213 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
101759 1.000 7.000 0.202 -0.513 0.973 0.243 -0.291 -0.213 0.289 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000
101760 1.000 7.000 -1.138 0.503 -0.785 -0.864 0.498 -0.213 -0.503 2.885 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000
101761 3.000 7.000 1.877 0.097 0.387 0.612 -0.291 -0.213 0.289 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000
101762 4.000 5.000 0.537 -1.530 0.973 -1.602 -0.291 -0.213 -0.503 0.816 ... 1.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 1.000 0.000

101760 rows × 161 columns

Save the cleansed data

In [23]:
diabetic_patient_data.to_csv("../_resources/diabetic_patient_data_cleansed.csv", sep=',')

Model Building

Divide your data into training and testing dataset

In [24]:
# split into train and test
from sklearn.model_selection import train_test_split


X = diabetic_patient_data.loc[:, diabetic_patient_data.columns != 'readmitted_YES']
y = diabetic_patient_data.loc[:, 'readmitted_YES']


X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    train_size=0.7,
                                                    test_size=0.3, random_state=100)

ML models appropriate for this classification problem

  • Decision tree
  • KNN

Decision Tree

In [25]:
# Importing decision tree classifier from sklearn library
from sklearn.tree import DecisionTreeClassifier

# Fitting the decision tree with default hyperparameters, apart from
# max_depth which is 5 so that we can plot and read the tree.
dt_default = DecisionTreeClassifier(max_depth=5)
dt_default.fit(X_train, y_train)
Out[25]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=5,
                       max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort=False,
                       random_state=None, splitter='best')
In [26]:
# Let's check the evaluation metrics of our default model

# Importing classification report and confusion matrix from sklearn metrics
from sklearn.metrics import classification_report, confusion_matrix

# Making predictions
y_pred_default = dt_default.predict(X_test)

# Printing classification report
print(classification_report(y_test, y_pred_default))
              precision    recall  f1-score   support

           0       0.59      0.67      0.63     16553
           1       0.53      0.44      0.48     13975

    accuracy                           0.57     30528
   macro avg       0.56      0.56      0.55     30528
weighted avg       0.56      0.57      0.56     30528

In [27]:
# Printing confusion matrix and accuracy
print(confusion_matrix(y_test,y_pred_default))
[[11093  5460]
 [ 7785  6190]]
In [28]:
# Confidence level per prediction

dt_pred_prob = dt_default.predict_proba(X_train)
dt_pred_prob
Out[28]:
array([[0.62029746, 0.37970254],
       [0.46863469, 0.53136531],
       [0.51388407, 0.48611593],
       ...,
       [0.46863469, 0.53136531],
       [0.57825046, 0.42174954],
       [0.54361294, 0.45638706]])

Visualizing the Decision Tree

In [29]:
# Importing required packages for visualization
from IPython.display import Image  
from sklearn.externals.six import StringIO  
from sklearn.tree import export_graphviz
import pydotplus, graphviz

# Putting features
features = list(diabetic_patient_data.columns[1:])
features
/Users/eklavya/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/sklearn/externals/six.py:31: DeprecationWarning: The module is deprecated in version 0.21 and will be removed in version 0.23 since we've dropped support for Python 2.7. Please rely on the official version of six (https://pypi.org/project/six/).
  "(https://pypi.org/project/six/).", DeprecationWarning)
Out[29]:
['admission_source_id',
 'time_in_hospital',
 'num_lab_procedures',
 'num_procedures',
 'num_medications',
 'number_outpatient',
 'number_emergency',
 'number_inpatient',
 'number_diagnoses',
 'admission_type_id_2',
 'admission_type_id_3',
 'admission_type_id_4',
 'admission_type_id_5',
 'admission_type_id_6',
 'admission_type_id_7',
 'admission_type_id_8',
 'race_Asian',
 'race_Caucasian',
 'race_Hispanic',
 'race_Other',
 'gender_Male',
 'gender_Unknown/Invalid',
 'age_[10-20)',
 'age_[20-30)',
 'age_[30-40)',
 'age_[40-50)',
 'age_[50-60)',
 'age_[60-70)',
 'age_[70-80)',
 'age_[80-90)',
 'age_[90-100)',
 'medical_specialty_Anesthesiology',
 'medical_specialty_Anesthesiology-Pediatric',
 'medical_specialty_Cardiology',
 'medical_specialty_Cardiology-Pediatric',
 'medical_specialty_DCPTEAM',
 'medical_specialty_Dentistry',
 'medical_specialty_Dermatology',
 'medical_specialty_Emergency/Trauma',
 'medical_specialty_Endocrinology',
 'medical_specialty_Endocrinology-Metabolism',
 'medical_specialty_Family/GeneralPractice',
 'medical_specialty_Gastroenterology',
 'medical_specialty_Gynecology',
 'medical_specialty_Hematology',
 'medical_specialty_Hematology/Oncology',
 'medical_specialty_Hospitalist',
 'medical_specialty_InfectiousDiseases',
 'medical_specialty_InternalMedicine',
 'medical_specialty_Nephrology',
 'medical_specialty_Neurology',
 'medical_specialty_Neurophysiology',
 'medical_specialty_Obsterics&Gynecology-GynecologicOnco',
 'medical_specialty_Obstetrics',
 'medical_specialty_ObstetricsandGynecology',
 'medical_specialty_Oncology',
 'medical_specialty_Ophthalmology',
 'medical_specialty_Orthopedics',
 'medical_specialty_Orthopedics-Reconstructive',
 'medical_specialty_Osteopath',
 'medical_specialty_Otolaryngology',
 'medical_specialty_OutreachServices',
 'medical_specialty_Pathology',
 'medical_specialty_Pediatrics',
 'medical_specialty_Pediatrics-AllergyandImmunology',
 'medical_specialty_Pediatrics-CriticalCare',
 'medical_specialty_Pediatrics-EmergencyMedicine',
 'medical_specialty_Pediatrics-Endocrinology',
 'medical_specialty_Pediatrics-Hematology-Oncology',
 'medical_specialty_Pediatrics-InfectiousDiseases',
 'medical_specialty_Pediatrics-Neurology',
 'medical_specialty_Pediatrics-Pulmonology',
 'medical_specialty_Perinatology',
 'medical_specialty_PhysicalMedicineandRehabilitation',
 'medical_specialty_PhysicianNotFound',
 'medical_specialty_Podiatry',
 'medical_specialty_Proctology',
 'medical_specialty_Psychiatry',
 'medical_specialty_Psychiatry-Addictive',
 'medical_specialty_Psychiatry-Child/Adolescent',
 'medical_specialty_Psychology',
 'medical_specialty_Pulmonology',
 'medical_specialty_Radiologist',
 'medical_specialty_Radiology',
 'medical_specialty_Resident',
 'medical_specialty_Rheumatology',
 'medical_specialty_Speech',
 'medical_specialty_SportsMedicine',
 'medical_specialty_Surgeon',
 'medical_specialty_Surgery-Cardiovascular',
 'medical_specialty_Surgery-Cardiovascular/Thoracic',
 'medical_specialty_Surgery-Colon&Rectal',
 'medical_specialty_Surgery-General',
 'medical_specialty_Surgery-Maxillofacial',
 'medical_specialty_Surgery-Neuro',
 'medical_specialty_Surgery-Pediatric',
 'medical_specialty_Surgery-Plastic',
 'medical_specialty_Surgery-PlasticwithinHeadandNeck',
 'medical_specialty_Surgery-Thoracic',
 'medical_specialty_Surgery-Vascular',
 'medical_specialty_SurgicalSpecialty',
 'medical_specialty_Unknown',
 'medical_specialty_Urology',
 'max_glu_serum_>300',
 'max_glu_serum_None',
 'max_glu_serum_Norm',
 'A1Cresult_>8',
 'A1Cresult_None',
 'A1Cresult_Norm',
 'metformin_No',
 'metformin_Steady',
 'metformin_Up',
 'repaglinide_No',
 'repaglinide_Steady',
 'repaglinide_Up',
 'nateglinide_No',
 'nateglinide_Steady',
 'nateglinide_Up',
 'chlorpropamide_No',
 'chlorpropamide_Steady',
 'chlorpropamide_Up',
 'glimepiride_No',
 'glimepiride_Steady',
 'glimepiride_Up',
 'acetohexamide_Steady',
 'glipizide_No',
 'glipizide_Steady',
 'glipizide_Up',
 'glyburide_No',
 'glyburide_Steady',
 'glyburide_Up',
 'tolbutamide_Steady',
 'pioglitazone_No',
 'pioglitazone_Steady',
 'pioglitazone_Up',
 'rosiglitazone_No',
 'rosiglitazone_Steady',
 'rosiglitazone_Up',
 'acarbose_No',
 'acarbose_Steady',
 'acarbose_Up',
 'miglitol_No',
 'miglitol_Steady',
 'miglitol_Up',
 'troglitazone_Steady',
 'tolazamide_Steady',
 'tolazamide_Up',
 'insulin_No',
 'insulin_Steady',
 'insulin_Up',
 'glyburide-metformin_No',
 'glyburide-metformin_Steady',
 'glyburide-metformin_Up',
 'glipizide-metformin_Steady',
 'glimepiride-pioglitazone_Steady',
 'metformin-rosiglitazone_Steady',
 'metformin-pioglitazone_Steady',
 'change_No',
 'diabetesMed_Yes',
 'readmitted_YES']
In [30]:
# plotting tree with max_depth=3
dot_data = StringIO()  
export_graphviz(dt_default, out_file=dot_data,
                feature_names=features, filled=True,rounded=True)

graph = pydotplus.graph_from_dot_data(dot_data.getvalue())  
Image(graph.create_png())
Out[30]:

Tuning Hyperparameters for improving the accuracy of the model

In [33]:
# GridSearchCV to find optimal max_depth
from sklearn.model_selection import KFold
from sklearn.model_selection import GridSearchCV


# specify number of folds for k-fold CV
n_folds = 5

# parameters to build the model on
parameters = {'max_depth': range(1, 40)}

# instantiate the model
dtree = DecisionTreeClassifier(criterion = "gini", 
                               random_state = 100)

# fit tree on training data
tree = GridSearchCV(dtree, parameters, cv=n_folds,return_train_score=True,scoring="accuracy")

tree.fit(X_train, y_train)
Out[33]:
GridSearchCV(cv=5, error_score='raise-deprecating',
             estimator=DecisionTreeClassifier(class_weight=None,
                                              criterion='gini', max_depth=None,
                                              max_features=None,
                                              max_leaf_nodes=None,
                                              min_impurity_decrease=0.0,
                                              min_impurity_split=None,
                                              min_samples_leaf=1,
                                              min_samples_split=2,
                                              min_weight_fraction_leaf=0.0,
                                              presort=False, random_state=100,
                                              splitter='best'),
             iid='warn', n_jobs=None, param_grid={'max_depth': range(1, 40)},
             pre_dispatch='2*n_jobs', refit=True, return_train_score=True,
             scoring='accuracy', verbose=0)
In [34]:
# scores of GridSearch CV
scores = tree.cv_results_
pd.DataFrame(scores).head()
Out[34]:
mean_fit_time std_fit_time mean_score_time std_score_time param_max_depth params split0_test_score split1_test_score split2_test_score split3_test_score ... mean_test_score std_test_score rank_test_score split0_train_score split1_train_score split2_train_score split3_train_score split4_train_score mean_train_score std_train_score
0 0.226 0.063 0.023 0.004 1 {'max_depth': 1} 0.538 0.538 0.538 0.538 ... 0.538 0.000 31 0.538 0.538 0.538 0.538 0.538 0.538 0.000
1 0.219 0.010 0.022 0.003 2 {'max_depth': 2} 0.541 0.543 0.543 0.549 ... 0.545 0.003 22 0.546 0.546 0.546 0.544 0.544 0.545 0.001
2 0.261 0.004 0.022 0.003 3 {'max_depth': 3} 0.555 0.556 0.549 0.551 ... 0.554 0.003 17 0.559 0.559 0.551 0.551 0.558 0.556 0.004
3 0.361 0.043 0.024 0.004 4 {'max_depth': 4} 0.563 0.567 0.560 0.561 ... 0.564 0.003 10 0.571 0.569 0.561 0.559 0.569 0.566 0.005
4 0.432 0.034 0.026 0.005 5 {'max_depth': 5} 0.562 0.566 0.573 0.572 ... 0.568 0.004 8 0.572 0.571 0.572 0.572 0.572 0.572 0.000

5 rows × 21 columns

In [36]:
# plotting accuracies with max_depth
plt.figure()
plt.plot(scores["param_max_depth"], 
         scores["mean_train_score"], 
         label="training accuracy")
plt.plot(scores["param_max_depth"], 
         scores["mean_test_score"], 
         label="test accuracy")
plt.xlabel("max_depth")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

We can see from the above plot that the Decision Tree model is clearly overfitting in case of training data and consequently underperforming in case of testing data

Tuning min_samples_leaf

In [37]:
# GridSearchCV to find optimal max_depth
from sklearn.model_selection import KFold
from sklearn.model_selection import GridSearchCV


# specify number of folds for k-fold CV
n_folds = 5

# parameters to build the model on
parameters = {'min_samples_leaf': range(5, 200, 20)}

# instantiate the model
dtree = DecisionTreeClassifier(criterion = "gini", 
                               random_state = 100)

# fit tree on training data
tree = GridSearchCV(dtree, parameters, 
                    cv=n_folds, return_train_score=True,
                   scoring="accuracy")
tree.fit(X_train, y_train)
Out[37]:
GridSearchCV(cv=5, error_score='raise-deprecating',
             estimator=DecisionTreeClassifier(class_weight=None,
                                              criterion='gini', max_depth=None,
                                              max_features=None,
                                              max_leaf_nodes=None,
                                              min_impurity_decrease=0.0,
                                              min_impurity_split=None,
                                              min_samples_leaf=1,
                                              min_samples_split=2,
                                              min_weight_fraction_leaf=0.0,
                                              presort=False, random_state=100,
                                              splitter='best'),
             iid='warn', n_jobs=None,
             param_grid={'min_samples_leaf': range(5, 200, 20)},
             pre_dispatch='2*n_jobs', refit=True, return_train_score=True,
             scoring='accuracy', verbose=0)
In [38]:
# scores of GridSearch CV
scores = tree.cv_results_
pd.DataFrame(scores).head()
Out[38]:
mean_fit_time std_fit_time mean_score_time std_score_time param_min_samples_leaf params split0_test_score split1_test_score split2_test_score split3_test_score ... mean_test_score std_test_score rank_test_score split0_train_score split1_train_score split2_train_score split3_train_score split4_train_score mean_train_score std_train_score
0 2.103 0.168 0.034 0.006 5 {'min_samples_leaf': 5} 0.539 0.542 0.537 0.536 ... 0.539 0.002 10 0.817 0.818 0.817 0.817 0.816 0.817 0.001
1 1.572 0.219 0.028 0.010 25 {'min_samples_leaf': 25} 0.551 0.546 0.557 0.549 ... 0.551 0.004 9 0.670 0.671 0.670 0.670 0.669 0.670 0.001
2 1.543 0.220 0.043 0.037 45 {'min_samples_leaf': 45} 0.557 0.555 0.560 0.563 ... 0.559 0.003 8 0.640 0.639 0.639 0.640 0.639 0.639 0.000
3 1.153 0.050 0.022 0.001 65 {'min_samples_leaf': 65} 0.564 0.560 0.572 0.568 ... 0.566 0.004 7 0.625 0.626 0.624 0.625 0.625 0.625 0.001
4 1.072 0.014 0.023 0.001 85 {'min_samples_leaf': 85} 0.569 0.561 0.567 0.566 ... 0.566 0.003 6 0.617 0.617 0.616 0.616 0.615 0.616 0.001

5 rows × 21 columns

In [39]:
# plotting accuracies with min_samples_leaf
plt.figure()
plt.plot(scores["param_min_samples_leaf"], 
         scores["mean_train_score"], 
         label="training accuracy")
plt.plot(scores["param_min_samples_leaf"], 
         scores["mean_test_score"], 
         label="test accuracy")
plt.xlabel("min_samples_leaf")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

After tuning the min_samples_leaf we notice that the accuracy of the model stabilizes after the value of 150

Tuning min_samples_split

In [40]:
# GridSearchCV to find optimal min_samples_split
from sklearn.model_selection import KFold
from sklearn.model_selection import GridSearchCV


# specify number of folds for k-fold CV
n_folds = 5

# parameters to build the model on
parameters = {'min_samples_split': range(5, 200, 20)}

# instantiate the model
dtree = DecisionTreeClassifier(criterion = "gini", 
                               random_state = 100)

# fit tree on training data
tree = GridSearchCV(dtree, parameters, 
                    cv=n_folds, return_train_score=True,
                   scoring="accuracy")
tree.fit(X_train, y_train)
Out[40]:
GridSearchCV(cv=5, error_score='raise-deprecating',
             estimator=DecisionTreeClassifier(class_weight=None,
                                              criterion='gini', max_depth=None,
                                              max_features=None,
                                              max_leaf_nodes=None,
                                              min_impurity_decrease=0.0,
                                              min_impurity_split=None,
                                              min_samples_leaf=1,
                                              min_samples_split=2,
                                              min_weight_fraction_leaf=0.0,
                                              presort=False, random_state=100,
                                              splitter='best'),
             iid='warn', n_jobs=None,
             param_grid={'min_samples_split': range(5, 200, 20)},
             pre_dispatch='2*n_jobs', refit=True, return_train_score=True,
             scoring='accuracy', verbose=0)
In [42]:
# scores of GridSearch CV
scores = tree.cv_results_
pd.DataFrame(scores).head()
Out[42]:
mean_fit_time std_fit_time mean_score_time std_score_time param_min_samples_split params split0_test_score split1_test_score split2_test_score split3_test_score ... mean_test_score std_test_score rank_test_score split0_train_score split1_train_score split2_train_score split3_train_score split4_train_score mean_train_score std_train_score
0 2.132 0.394 0.049 0.043 5 {'min_samples_split': 5} 0.540 0.532 0.541 0.535 ... 0.537 0.003 10 0.950 0.951 0.951 0.950 0.951 0.951 0.001
1 1.634 0.116 0.027 0.007 25 {'min_samples_split': 25} 0.538 0.538 0.542 0.542 ... 0.541 0.003 9 0.785 0.785 0.787 0.786 0.785 0.785 0.001
2 2.052 0.174 0.030 0.008 45 {'min_samples_split': 45} 0.539 0.541 0.545 0.541 ... 0.543 0.003 8 0.733 0.732 0.732 0.735 0.730 0.733 0.002
3 1.948 0.184 0.026 0.002 65 {'min_samples_split': 65} 0.543 0.545 0.549 0.547 ... 0.547 0.002 7 0.704 0.705 0.705 0.706 0.704 0.705 0.001
4 2.084 0.330 0.028 0.009 85 {'min_samples_split': 85} 0.551 0.547 0.550 0.549 ... 0.550 0.002 6 0.687 0.686 0.688 0.687 0.686 0.687 0.001

5 rows × 21 columns

In [43]:
# plotting accuracies with min_samples_leaf
plt.figure()
plt.plot(scores["param_min_samples_split"], 
         scores["mean_train_score"], 
         label="training accuracy")
plt.plot(scores["param_min_samples_split"], 
         scores["mean_test_score"], 
         label="test accuracy")
plt.xlabel("min_samples_split")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

After tuning the min_samples_split we notice that the accuracy of the model stabilizes after the value of 150

Grid Search to Find Optimal Hyperparameters

In [54]:
# Create the parameter grid 
param_grid = {
    'max_depth': range(5, 15, 5),
    'min_samples_leaf': range(50, 200, 50),
    'min_samples_split': range(50, 200, 50),
    'criterion': ["entropy", "gini"]
}

n_folds = 5

# Instantiate the grid search model
dtree = DecisionTreeClassifier()
grid_search = GridSearchCV(estimator = dtree, param_grid = param_grid, 
                          cv = n_folds, verbose = 1,return_train_score=True)

# Fit the grid search to the data
grid_search.fit(X_train,y_train)
Fitting 5 folds for each of 36 candidates, totalling 180 fits
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 180 out of 180 | elapsed:  2.6min finished
Out[54]:
GridSearchCV(cv=5, error_score='raise-deprecating',
             estimator=DecisionTreeClassifier(class_weight=None,
                                              criterion='gini', max_depth=None,
                                              max_features=None,
                                              max_leaf_nodes=None,
                                              min_impurity_decrease=0.0,
                                              min_impurity_split=None,
                                              min_samples_leaf=1,
                                              min_samples_split=2,
                                              min_weight_fraction_leaf=0.0,
                                              presort=False, random_state=None,
                                              splitter='best'),
             iid='warn', n_jobs=None,
             param_grid={'criterion': ['entropy', 'gini'],
                         'max_depth': range(5, 15, 5),
                         'min_samples_leaf': range(50, 200, 50),
                         'min_samples_split': range(50, 200, 50)},
             pre_dispatch='2*n_jobs', refit=True, return_train_score=True,
             scoring=None, verbose=1)
In [55]:
# cv results
cv_results = pd.DataFrame(grid_search.cv_results_)
cv_results
Out[55]:
mean_fit_time std_fit_time mean_score_time std_score_time param_criterion param_max_depth param_min_samples_leaf param_min_samples_split params split0_test_score ... mean_test_score std_test_score rank_test_score split0_train_score split1_train_score split2_train_score split3_train_score split4_train_score mean_train_score std_train_score
0 0.546 0.151 0.021 0.002 entropy 5 50 50 {'criterion': 'entropy', 'max_depth': 5, 'min_... 0.562 ... 0.567 0.004 19 0.570 0.568 0.567 0.567 0.569 0.568 0.001
1 0.413 0.010 0.020 0.000 entropy 5 50 100 {'criterion': 'entropy', 'max_depth': 5, 'min_... 0.562 ... 0.567 0.004 19 0.570 0.568 0.567 0.567 0.569 0.568 0.001
2 0.422 0.034 0.020 0.001 entropy 5 50 150 {'criterion': 'entropy', 'max_depth': 5, 'min_... 0.562 ... 0.567 0.004 19 0.570 0.568 0.567 0.567 0.569 0.568 0.001
3 0.497 0.057 0.022 0.001 entropy 5 100 50 {'criterion': 'entropy', 'max_depth': 5, 'min_... 0.561 ... 0.567 0.004 31 0.569 0.567 0.567 0.567 0.567 0.567 0.001
4 0.469 0.085 0.021 0.001 entropy 5 100 100 {'criterion': 'entropy', 'max_depth': 5, 'min_... 0.561 ... 0.567 0.004 31 0.569 0.567 0.567 0.567 0.567 0.567 0.001
5 0.457 0.058 0.024 0.005 entropy 5 100 150 {'criterion': 'entropy', 'max_depth': 5, 'min_... 0.561 ... 0.567 0.004 31 0.569 0.567 0.567 0.567 0.567 0.567 0.001
6 0.451 0.051 0.029 0.009 entropy 5 150 50 {'criterion': 'entropy', 'max_depth': 5, 'min_... 0.561 ... 0.567 0.004 25 0.568 0.567 0.567 0.567 0.567 0.567 0.000
7 0.559 0.098 0.027 0.009 entropy 5 150 100 {'criterion': 'entropy', 'max_depth': 5, 'min_... 0.561 ... 0.567 0.004 25 0.568 0.567 0.567 0.567 0.567 0.567 0.000
8 0.516 0.120 0.026 0.007 entropy 5 150 150 {'criterion': 'entropy', 'max_depth': 5, 'min_... 0.561 ... 0.567 0.004 25 0.568 0.567 0.567 0.567 0.567 0.567 0.000
9 0.758 0.047 0.025 0.005 entropy 10 50 50 {'criterion': 'entropy', 'max_depth': 10, 'min... 0.567 ... 0.571 0.003 8 0.584 0.582 0.581 0.579 0.583 0.582 0.002
10 1.075 0.103 0.035 0.010 entropy 10 50 100 {'criterion': 'entropy', 'max_depth': 10, 'min... 0.567 ... 0.571 0.003 8 0.584 0.582 0.581 0.579 0.583 0.582 0.002
11 0.767 0.112 0.027 0.008 entropy 10 50 150 {'criterion': 'entropy', 'max_depth': 10, 'min... 0.568 ... 0.572 0.003 2 0.583 0.581 0.580 0.579 0.583 0.581 0.002
12 0.773 0.066 0.023 0.004 entropy 10 100 50 {'criterion': 'entropy', 'max_depth': 10, 'min... 0.569 ... 0.571 0.002 16 0.584 0.582 0.582 0.578 0.583 0.582 0.002
13 0.923 0.148 0.023 0.003 entropy 10 100 100 {'criterion': 'entropy', 'max_depth': 10, 'min... 0.569 ... 0.571 0.002 16 0.584 0.582 0.582 0.578 0.583 0.582 0.002
14 0.855 0.225 0.024 0.004 entropy 10 100 150 {'criterion': 'entropy', 'max_depth': 10, 'min... 0.569 ... 0.571 0.002 16 0.584 0.582 0.582 0.578 0.583 0.582 0.002
15 0.939 0.150 0.025 0.006 entropy 10 150 50 {'criterion': 'entropy', 'max_depth': 10, 'min... 0.568 ... 0.571 0.002 8 0.582 0.582 0.581 0.580 0.581 0.581 0.001
16 1.407 0.847 0.035 0.013 entropy 10 150 100 {'criterion': 'entropy', 'max_depth': 10, 'min... 0.568 ... 0.571 0.002 8 0.582 0.582 0.581 0.580 0.581 0.581 0.001
17 1.658 0.519 0.038 0.006 entropy 10 150 150 {'criterion': 'entropy', 'max_depth': 10, 'min... 0.568 ... 0.571 0.002 8 0.582 0.582 0.581 0.580 0.581 0.581 0.001
18 1.277 0.961 0.034 0.021 gini 5 50 50 {'criterion': 'gini', 'max_depth': 5, 'min_sam... 0.562 ... 0.567 0.004 22 0.570 0.568 0.567 0.567 0.569 0.568 0.001
19 0.583 0.069 0.032 0.006 gini 5 50 100 {'criterion': 'gini', 'max_depth': 5, 'min_sam... 0.562 ... 0.567 0.004 22 0.570 0.568 0.567 0.567 0.569 0.568 0.001
20 0.642 0.124 0.044 0.022 gini 5 50 150 {'criterion': 'gini', 'max_depth': 5, 'min_sam... 0.562 ... 0.567 0.004 22 0.570 0.568 0.567 0.567 0.569 0.568 0.001
21 0.428 0.025 0.021 0.001 gini 5 100 50 {'criterion': 'gini', 'max_depth': 5, 'min_sam... 0.561 ... 0.566 0.004 34 0.569 0.567 0.567 0.567 0.568 0.567 0.001
22 0.489 0.062 0.024 0.005 gini 5 100 100 {'criterion': 'gini', 'max_depth': 5, 'min_sam... 0.561 ... 0.566 0.004 34 0.569 0.567 0.567 0.567 0.568 0.567 0.001
23 0.519 0.058 0.030 0.011 gini 5 100 150 {'criterion': 'gini', 'max_depth': 5, 'min_sam... 0.561 ... 0.566 0.004 34 0.569 0.567 0.567 0.567 0.568 0.567 0.001
24 0.731 0.109 0.041 0.020 gini 5 150 50 {'criterion': 'gini', 'max_depth': 5, 'min_sam... 0.561 ... 0.567 0.004 25 0.568 0.567 0.567 0.567 0.567 0.567 0.000
25 0.600 0.129 0.025 0.004 gini 5 150 100 {'criterion': 'gini', 'max_depth': 5, 'min_sam... 0.561 ... 0.567 0.004 25 0.568 0.567 0.567 0.567 0.567 0.567 0.000
26 0.519 0.080 0.032 0.020 gini 5 150 150 {'criterion': 'gini', 'max_depth': 5, 'min_sam... 0.561 ... 0.567 0.004 25 0.568 0.567 0.567 0.567 0.567 0.567 0.000
27 0.774 0.049 0.023 0.002 gini 10 50 50 {'criterion': 'gini', 'max_depth': 10, 'min_sa... 0.567 ... 0.571 0.003 6 0.584 0.583 0.581 0.579 0.584 0.582 0.002
28 0.966 0.146 0.028 0.005 gini 10 50 100 {'criterion': 'gini', 'max_depth': 10, 'min_sa... 0.567 ... 0.571 0.003 6 0.584 0.583 0.581 0.579 0.584 0.582 0.002
29 0.792 0.170 0.025 0.009 gini 10 50 150 {'criterion': 'gini', 'max_depth': 10, 'min_sa... 0.568 ... 0.572 0.003 1 0.583 0.583 0.581 0.579 0.584 0.582 0.002
30 1.059 0.136 0.032 0.012 gini 10 100 50 {'criterion': 'gini', 'max_depth': 10, 'min_sa... 0.569 ... 0.572 0.002 3 0.584 0.582 0.582 0.580 0.583 0.582 0.001
31 1.079 0.095 0.028 0.006 gini 10 100 100 {'criterion': 'gini', 'max_depth': 10, 'min_sa... 0.569 ... 0.572 0.002 3 0.584 0.582 0.582 0.580 0.583 0.582 0.001
32 0.700 0.027 0.021 0.001 gini 10 100 150 {'criterion': 'gini', 'max_depth': 10, 'min_sa... 0.569 ... 0.572 0.002 3 0.584 0.582 0.582 0.580 0.583 0.582 0.001
33 0.679 0.006 0.021 0.001 gini 10 150 50 {'criterion': 'gini', 'max_depth': 10, 'min_sa... 0.568 ... 0.571 0.002 13 0.582 0.582 0.581 0.581 0.581 0.581 0.000
34 1.076 0.405 0.025 0.006 gini 10 150 100 {'criterion': 'gini', 'max_depth': 10, 'min_sa... 0.568 ... 0.571 0.002 13 0.582 0.582 0.581 0.581 0.581 0.581 0.000
35 0.959 0.196 0.027 0.008 gini 10 150 150 {'criterion': 'gini', 'max_depth': 10, 'min_sa... 0.568 ... 0.571 0.002 13 0.582 0.582 0.581 0.581 0.581 0.581 0.000

36 rows × 24 columns

In [56]:
# printing the optimal accuracy score and hyperparameters
print("best accuracy", grid_search.best_score_)
print(grid_search.best_estimator_)
best accuracy 0.5720041554357592
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=10,
                       max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=50, min_samples_split=150,
                       min_weight_fraction_leaf=0.0, presort=False,
                       random_state=None, splitter='best')
In [57]:
# model with optimal hyperparameters
clf_gini = DecisionTreeClassifier(criterion = "gini", 
                                  random_state = 100,
                                  max_depth=10, 
                                  min_samples_leaf=50,
                                  min_samples_split=50)
clf_gini.fit(X_train, y_train)
Out[57]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=10,
                       max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=50, min_samples_split=50,
                       min_weight_fraction_leaf=0.0, presort=False,
                       random_state=100, splitter='best')
In [58]:
# accuracy score
clf_gini.score(X_test,y_test)
Out[58]:
0.5713115828092243
In [59]:
# plotting the tree
dot_data = StringIO()  
export_graphviz(clf_gini, out_file=dot_data,feature_names=features,filled=True,rounded=True)

graph = pydotplus.graph_from_dot_data(dot_data.getvalue())  
Image(graph.create_png())
Out[59]:
In [61]:
# tree with max_depth = 3
clf_gini = DecisionTreeClassifier(criterion = "gini", 
                                  random_state = 100,
                                  max_depth=3, 
                                  min_samples_leaf=50,
                                  min_samples_split=50)
clf_gini.fit(X_train, y_train)

# score
print(clf_gini.score(X_test,y_test))
0.5528039832285115
In [62]:
# plotting tree with max_depth=3
dot_data = StringIO()  
export_graphviz(clf_gini, out_file=dot_data,feature_names=features,filled=True,rounded=True)

graph = pydotplus.graph_from_dot_data(dot_data.getvalue())  
Image(graph.create_png())
Out[62]:
In [63]:
# classification metrics
from sklearn.metrics import classification_report,confusion_matrix
y_pred = clf_gini.predict(X_test)
print(classification_report(y_test, y_pred))
              precision    recall  f1-score   support

           0       0.55      0.95      0.70     16553
           1       0.58      0.08      0.15     13975

    accuracy                           0.55     30528
   macro avg       0.57      0.52      0.42     30528
weighted avg       0.56      0.55      0.44     30528

In [64]:
# confusion matrix
print(confusion_matrix(y_test,y_pred))
[[15705   848]
 [12804  1171]]

KNN model training

In [24]:
from sklearn.neighbors import KNeighborsClassifier

# Train the KNN classifier for 2 neighbours
classifier = KNeighborsClassifier(n_neighbors=2)
classifier.fit(X_train, y_train)
Out[24]:
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
                     metric_params=None, n_jobs=None, n_neighbors=2, p=2,
                     weights='uniform')
In [25]:
# Use the classifier to predict the values from the X_test dataset
y_pred = classifier.predict(X_test)
In [26]:
# As an objective measure, we use the confusion matrix to see how the model performed

from sklearn.metrics import classification_report, confusion_matrix
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))
[[13003  3550]
 [10337  3638]]
              precision    recall  f1-score   support

           0       0.56      0.79      0.65     16553
           1       0.51      0.26      0.34     13975

    accuracy                           0.55     30528
   macro avg       0.53      0.52      0.50     30528
weighted avg       0.53      0.55      0.51     30528

Plotting the ROC curve to a visual analysis of the model's performance

In [29]:
from sklearn.metrics import roc_curve
from sklearn.metrics import auc

# Find the confidence level of each prediction
knn_pred_prob = classifier.predict_proba(X_test)

fpr, tpr, threshold = roc_curve(y_test, knn_pred_prob[:, 1])
roc_auc = auc(fpr, tpr)

plt.title('Receiver Operating Characteristic')
plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
plt.legend(loc = 'lower right')
plt.plot([0, 1], [0, 1],'r--')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.title('ROC Curve of kNN')
plt.show()

KNN - Parameter tuning

Iterating over a range of values for n_neighbours to obtain an optimum value

In [32]:
from sklearn.neighbors import KNeighborsClassifier


error = []

# Calculating error for K values between 1 and 10
for i in range(1, 10):
    print("Currently training with : ", i)
    knn = KNeighborsClassifier(n_neighbors=i)
    knn.fit(X_train, y_train)
    pred_i = knn.predict(X_test)
    error.append(np.mean(pred_i != y_test))
    
    
error
Currently training with :  1
Currently training with :  2
Currently training with :  3
Currently training with :  4
Currently training with :  5
Currently training with :  6
Currently training with :  7
Currently training with :  8
Currently training with :  9
Out[32]:
[0.4677017819706499,
 0.4548938679245283,
 0.4632796121593291,
 0.45384564989517817,
 0.4545662997903564,
 0.4473270440251572,
 0.44945623689727465,
 0.44483752620545075,
 0.4502751572327044]
In [33]:
plt.figure(figsize=(12, 6))
plt.plot(range(1, 10), error, color='red', linestyle='dashed', marker='o',
         markerfacecolor='blue', markersize=10)
plt.title('Error Rate K Value')
plt.xlabel('K Value')
plt.ylabel('Mean Error')
Out[33]:
Text(0, 0.5, 'Mean Error')

Using the optimum value of neighbours from the above analysis

In [34]:
classifier = KNeighborsClassifier(n_neighbors=8, 
                                  algorithm='kd_tree')

classifier.fit(X_train, y_train)
Out[34]:
KNeighborsClassifier(algorithm='kd_tree', leaf_size=30, metric='minkowski',
                     metric_params=None, n_jobs=None, n_neighbors=8, p=2,
                     weights='uniform')
In [37]:
y_pred = classifier.predict(X_test)
In [39]:
from sklearn.metrics import classification_report, confusion_matrix
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))
[[11679  4874]
 [ 8706  5269]]
              precision    recall  f1-score   support

           0       0.57      0.71      0.63     16553
           1       0.52      0.38      0.44     13975

    accuracy                           0.56     30528
   macro avg       0.55      0.54      0.53     30528
weighted avg       0.55      0.56      0.54     30528

In [40]:
from sklearn.metrics import roc_curve
from sklearn.metrics import auc

knn_pred_prob = classifier.predict_proba(X_test)

fpr, tpr, threshold = roc_curve(y_test, knn_pred_prob[:, 1])
roc_auc = auc(fpr, tpr)

plt.title('Receiver Operating Characteristic')
plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
plt.legend(loc = 'lower right')
plt.plot([0, 1], [0, 1],'r--')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.title('ROC Curve of kNN')
plt.show()

Chosen Classification Model

Decision Tree

Given the domain and the close coordination needed with the Doctors/ Medical personnel, I think it's better to go with a model with higher explicability such that the analysis as well as the Medical Personnel is completely aware of why and how the model arrived at a specific conclusion.

Population stratification

Use trained model to stratify your population into 3 risk buckets:

  • High risk (Probability of readmission >0.7)
  • Medium risk (0.3 < Probability of readmission < 0.7)
  • Low risk (Probability of readmission < 0.3)
In [41]:
# Here are the predicted probability using the Decision Tree model
dt_pred_prob
Out[41]:
array([[0.62029746, 0.37970254],
       [0.46863469, 0.53136531],
       [0.51388407, 0.48611593],
       ...,
       [0.46863469, 0.53136531],
       [0.57825046, 0.42174954],
       [0.54361294, 0.45638706]])
In [42]:
stratification_array = []

for a_confidence_arr in dt_pred_prob:
    risk = a_confidence_arr[0]
    if  ( risk >= 0.7 ):
        stratification_array.append('High')
    elif ( 0.3 <= risk < 0.7 ):
        stratification_array.append('Medium')
    else:
        stratification_array.append('Low')
In [43]:
# Add the `confidence_YES` column to the dataframe and assign it the inferred stratification values
X_train['confidence_YES'] = stratification_array
X_train
Out[43]:
discharge_disposition_id admission_source_id time_in_hospital num_lab_procedures num_procedures num_medications number_outpatient number_emergency number_inpatient number_diagnoses ... glyburide-metformin_No glyburide-metformin_Steady glyburide-metformin_Up glipizide-metformin_Steady glimepiride-pioglitazone_Steady metformin-rosiglitazone_Steady metformin-pioglitazone_Steady change_No diabetesMed_Yes confidence_YES
6912 6 4 -0.133 -0.310 -0.785 -0.864 -0.291 -0.213 -0.503 -0.219 ... 1 0 0 0 0 0 0 0 1 Medium
39101 3 7 2.547 0.605 0.973 -0.618 0.498 -0.213 -0.503 -0.736 ... 1 0 0 0 0 0 0 0 1 Medium
101401 1 7 -0.133 -0.462 -0.785 -0.126 -0.291 -0.213 1.872 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
56509 1 1 0.537 -0.462 -0.785 -0.987 2.076 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
56563 1 7 -0.468 0.046 -0.785 -0.126 0.498 -0.213 1.080 -0.219 ... 1 0 0 0 0 0 0 1 0 Medium
36588 1 1 -0.803 -0.005 -0.785 0.243 -0.291 -0.213 -0.503 -2.287 ... 1 0 0 0 0 0 0 0 1 Medium
75223 1 7 0.872 0.453 -0.785 0.120 2.865 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 1 0 Medium
24203 6 7 0.202 -1.326 -0.199 -0.987 -0.291 -0.213 0.289 0.299 ... 1 0 0 0 0 0 0 1 1 Medium
91829 1 7 -0.468 0.859 -0.785 1.474 -0.291 -0.213 1.872 0.299 ... 1 0 0 0 0 0 0 1 0 Medium
78719 6 7 -0.803 1.113 -0.199 -0.864 -0.291 -0.213 1.080 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
39414 1 4 0.537 0.859 -0.785 -0.864 -0.291 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 1 1 High
27431 1 7 -0.468 0.503 -0.785 -0.372 -0.291 -0.213 1.080 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
68451 1 7 -0.468 -0.767 -0.785 -0.618 0.498 0.862 -0.503 0.816 ... 1 0 0 0 0 0 0 1 0 Medium
64148 6 1 -0.803 1.164 -0.785 -0.249 1.287 1.937 0.289 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
77933 3 1 -0.468 0.605 -0.785 -0.864 2.076 -0.213 1.080 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
60940 1 7 -1.138 -1.631 0.973 -0.372 -0.291 -0.213 0.289 0.816 ... 0 1 0 0 0 0 0 0 1 Medium
26718 1 6 -1.138 -0.107 2.146 0.366 -0.291 0.862 0.289 0.816 ... 1 0 0 0 0 0 0 1 0 High
66076 1 1 -0.468 0.097 -0.785 0.489 -0.291 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
46677 5 1 -0.133 -0.411 -0.199 1.474 1.287 -0.213 0.289 -0.219 ... 1 0 0 0 0 0 0 1 0 Medium
71263 3 7 2.212 1.774 2.732 2.827 -0.291 -0.213 -0.503 -0.219 ... 1 0 0 0 0 0 0 0 1 Medium
39269 1 1 -0.803 1.215 -0.785 -1.233 -0.291 -0.213 -0.503 -1.770 ... 1 0 0 0 0 0 0 0 1 Medium
79285 3 1 -0.803 -0.818 -0.199 -0.618 -0.291 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 1 0 Medium
101250 3 7 -0.133 0.503 0.973 1.474 -0.291 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 1 1 Medium
59920 1 7 -1.138 -1.123 0.387 -0.741 2.076 -0.213 0.289 0.816 ... 1 0 0 0 0 0 0 1 1 Medium
63621 3 7 -1.138 -0.513 0.973 -0.249 -0.291 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 1 1 Medium
79700 6 1 -1.138 -0.411 -0.785 -0.741 -0.291 0.862 -0.503 -0.219 ... 1 0 0 0 0 0 0 1 0 Medium
6593 6 7 -0.803 0.707 -0.785 -1.110 -0.291 -0.213 -0.503 -0.736 ... 1 0 0 0 0 0 0 1 1 Medium
56872 4 1 -0.468 1.469 -0.199 -0.618 0.498 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 1 1 Medium
85728 22 7 -0.803 1.113 -0.785 0.120 1.287 -0.213 -0.503 0.299 ... 1 0 0 0 0 0 0 0 1 Medium
60713 1 7 -0.133 0.656 -0.785 -1.356 0.498 4.086 -0.503 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
21379 1 7 2.212 0.605 0.387 -0.618 -0.291 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 1 0 Medium
27041 7 7 -0.468 -1.326 -0.785 0.736 0.498 1.937 1.080 0.816 ... 1 0 0 0 0 0 0 1 0 Medium
30598 22 7 -0.803 0.097 -0.785 -0.495 -0.291 -0.213 1.872 -1.253 ... 1 0 0 0 0 0 0 0 1 Medium
83823 6 7 -0.468 -0.310 -0.785 -0.126 -0.291 -0.213 1.872 0.816 ... 0 1 0 0 0 0 0 0 1 Medium
71326 1 1 -0.803 -0.005 0.387 -0.864 -0.291 -0.213 -0.503 -0.219 ... 1 0 0 0 0 0 0 0 1 Medium
84357 22 7 -0.133 0.198 -0.785 0.243 -0.291 -0.213 0.289 0.816 ... 0 1 0 0 0 0 0 0 1 Medium
6231 1 7 -0.803 -1.733 -0.199 -0.003 -0.291 -0.213 -0.503 -2.805 ... 1 0 0 0 0 0 0 1 0 Medium
80992 6 1 0.537 1.012 0.973 1.351 -0.291 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
38539 1 7 -0.468 0.148 -0.785 -0.987 1.287 -0.213 -0.503 -0.219 ... 1 0 0 0 0 0 0 0 1 Medium
42493 1 7 0.537 0.148 -0.785 1.228 -0.291 -0.213 0.289 0.816 ... 1 0 0 0 0 0 0 1 1 Medium
7739 1 7 1.877 0.503 0.387 1.105 -0.291 -0.213 0.289 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
17726 1 1 -0.468 0.402 0.973 -0.987 -0.291 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
59245 1 1 -0.133 0.249 -0.785 0.243 -0.291 -0.213 0.289 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
91794 1 7 -1.138 -0.310 -0.785 -1.602 -0.291 -0.213 0.289 0.816 ... 1 0 0 0 0 0 0 1 1 Medium
86299 3 7 -0.468 0.453 -0.199 -0.003 -0.291 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
78836 6 7 -1.138 0.148 0.973 -0.864 0.498 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 1 1 Medium
49956 1 7 -0.803 -0.310 -0.785 -0.618 -0.291 -0.213 -0.503 -0.736 ... 1 0 0 0 0 0 0 1 1 Medium
71185 6 7 1.542 0.046 0.387 1.474 -0.291 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 1 1 Medium
13891 18 7 1.877 0.453 0.973 1.228 -0.291 -0.213 0.289 -1.253 ... 1 0 0 0 0 0 0 1 0 Medium
20535 11 7 1.542 0.148 -0.199 0.366 -0.291 -0.213 1.080 0.816 ... 1 0 0 0 0 0 0 1 1 High
75621 3 1 2.212 1.876 2.732 3.935 -0.291 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
14261 3 7 0.872 1.571 -0.199 0.859 0.498 -0.213 0.289 -0.219 ... 1 0 0 0 0 0 0 1 1 Medium
82273 1 7 -0.803 -2.038 -0.199 -0.864 -0.291 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
63373 23 7 0.872 1.367 2.732 1.105 -0.291 -0.213 3.456 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
81843 1 1 -0.803 -1.225 2.146 0.859 -0.291 -0.213 0.289 0.299 ... 1 0 0 0 0 0 0 0 1 Medium
65618 1 7 -1.138 -0.005 -0.785 -0.741 -0.291 -0.213 0.289 0.816 ... 1 0 0 0 0 0 0 1 1 Medium
77658 3 5 0.537 1.266 -0.785 0.489 1.287 0.862 0.289 0.816 ... 1 0 0 0 0 0 0 0 1 Medium
79686 3 7 1.207 0.758 2.732 1.720 1.287 -0.213 2.664 0.816 ... 1 0 0 0 0 0 0 1 1 Medium
56090 1 7 0.202 1.012 -0.785 -0.372 -0.291 -0.213 -0.503 0.816 ... 1 0 0 0 0 0 0 1 0 Medium
38410 1 1 -0.468 0.198 -0.199 0.120 1.287 -0.213 -0.503 -2.287 ... 1 0 0 0 0 0 0 0 1 Medium

71232 rows × 161 columns